import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from recbole_cdr.model.crossdomain_recommender import CrossDomainRecommender
from recbole_cdr.model.cross_domain_recommender.graphcdr import GraphCDR
from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import EmbLoss
from recbole.utils import InputType
import dgl
import dgl.function as fn
from dgl.nn import GraphConv,GATConv,SAGEConv
from hyperbolic_gnn.model.hgcn.layers.gnn import GAT,GCN,GATv2,GraphSAGE
from hyperbolic_gnn.model.hgcn.layers.euclidean_contrastive_learning import GraphContrastive

class DGLGNN(GraphCDR):
    """Cross-domain recommender running DGL graph convolutions on one merged user-item graph.

    Source- and target-domain interactions are combined into a single bidirected graph over
    ``total_num_users + total_num_items`` nodes. Stacked DGL layers propagate the ego
    embeddings (``user_embedding`` / ``item_embedding``) over that graph.

    Config keys read here:
        dglconv: one of ``'GraphConv'``, ``'GATConv'``, ``'SAGE'``, ``'LightGCN'``.
        setting: ``'transfer'`` (source loss only), ``'merge'`` (tuple of source and target
            losses) or ``'source_passing'`` (target loss only).
        num_heads: number of attention heads (only used when ``dglconv == 'GATConv'``).
        device, embedding_size, reg_weight.
    """
    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        super(DGLGNN, self).__init__(config, dataset)
        # load dataset info
        self.config = config
        self.SOURCE_LABEL = dataset.source_domain_dataset.label_field
        self.TARGET_LABEL = dataset.target_domain_dataset.label_field
        self.device = config['device']
        self.latent_dim = config['embedding_size']
        # NOTE: the number of propagation layers is fixed at 2 (not read from config).
        self.n_layers = 2
        self.reg_weight = config['reg_weight']
        conv_type = config['dglconv']
        self.layers = nn.ModuleList([
            self._make_conv_layer(conv_type) for _ in range(self.n_layers)
        ]).to(self.device)
        # self.drop_rate is presumably set by GraphCDR.__init__ — not defined in this class.
        self.dropout = nn.Dropout(p=self.drop_rate)
        self.loss = nn.BCELoss()
        self.sigmoid = nn.Sigmoid()
        self.reg_loss = EmbLoss()
        # Cached full-graph embeddings used at evaluation time (cleared on every training step).
        self.target_restore_user_e = None
        self.target_restore_item_e = None
        self.other_parameter_name = ['target_restore_user_e', 'target_restore_item_e']
        # Build one merged graph over users and items of both domains.
        self.merge_dgl_graph = self.generate_dgl_graph().to(self.device)
        self.user_embedding = torch.nn.Embedding(num_embeddings=self.total_num_users,
                                                 embedding_dim=self.latent_dim).to(self.device)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.total_num_items,
                                                 embedding_dim=self.latent_dim).to(self.device)
        # Initialize only after every submodule (including the embeddings above) exists;
        # otherwise the embedding tables would keep PyTorch's default N(0, 1) init.
        self.apply(xavier_normal_initialization)

    def _make_conv_layer(self, conv_type):
        """Build one DGL convolution layer of the given type mapping latent_dim -> latent_dim.

        Raises:
            ValueError: if ``conv_type`` is not a supported layer name.
        """
        dim = self.latent_dim
        if conv_type == 'GraphConv':
            return GraphConv(dim, dim, norm='both', weight=True, bias=True, activation=None)
        if conv_type == 'GATConv':
            return GATConv(dim, dim, num_heads=self.config['num_heads'])
        if conv_type == 'SAGE':
            return SAGEConv(dim, dim, 'mean')
        if conv_type == 'LightGCN':
            # Parameter-free propagation. generate_dgl_graph adds no self-loops in this mode,
            # so nodes without edges must be tolerated.
            return GraphConv(dim, dim, norm='both', weight=False, bias=False,
                             activation=None, allow_zero_in_degree=True)
        raise ValueError(
            f"Unsupported dglconv type: {conv_type!r}; "
            "expected one of 'GraphConv', 'GATConv', 'SAGE', 'LightGCN'"
        )

    def get_ego_embeddings(self):
        """Return user and item embedding tables stacked row-wise (users first, then items)."""
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0).to(self.device)
        return ego_embeddings

    def generate_dgl_graph(self):
        """Build the merged, bidirected user-item DGL graph from both domains' interactions.

        Item node ids are offset by ``total_num_users`` so users and items share one node
        space. Self-loops are added for every conv type except LightGCN. Every edge gets
        weight 1.0 stored in ``edata['w']``.
        """
        all_users = torch.cat([self.target_u, self.source_u], dim=0)
        all_items = torch.cat([self.target_i, self.source_i], dim=0)
        num_users = self.total_num_users
        num_items = self.total_num_items
        all_items = all_items + num_users
        num_nodes_total = num_items + num_users
        # Construct the DGL graph and set edge weights.
        g = dgl.graph((all_users, all_items), num_nodes=num_nodes_total)
        # Convert to an undirected (bidirected) graph.
        g = dgl.to_bidirected(g)
        if self.config['dglconv'] != 'LightGCN':
            g = dgl.add_self_loop(g)
        edge_weights = torch.ones((g.num_edges(), 1), dtype=torch.float32)
        g.edata['w'] = edge_weights
        return g

    def forward(self, graph):
        """Propagate ego embeddings over ``graph``.

        Returns:
            tuple: ``(user_all_embeddings, item_all_embeddings)``, with
            ``total_num_users`` and ``total_num_items`` rows respectively.
        """
        # Move the graph and edge weights once, not once per layer.
        graph = graph.to(self.device)
        edge_weight = graph.edata['w'].to(self.device)
        h = self.get_ego_embeddings()  # [num_nodes, emb_dim]
        is_lightgcn = self.config['dglconv'] == 'LightGCN'
        hs = [h]
        for layer in self.layers:
            h = layer(graph, h, edge_weight=edge_weight)
            if h.dim() == 3:
                # GATConv outputs [num_nodes, num_heads, emb_dim]; average the heads so the
                # next layer and the user/item split receive [num_nodes, emb_dim].
                h = h.mean(dim=1)
            hs.append(h)
        if is_lightgcn:
            # LightGCN averages the outputs of all layers, including layer 0 (the ego embeddings).
            out = torch.stack(hs, dim=0).mean(dim=0)
        else:
            out = h
        user_all_embeddings, item_all_embeddings = torch.split(
            out, [self.total_num_users, self.total_num_items])
        return user_all_embeddings, item_all_embeddings

    def _domain_loss(self, user, item, label, user_all_embeddings, item_all_embeddings):
        """BCE loss on propagated-embedding scores plus weighted reg loss on the ego embeddings."""
        output = self.sigmoid(
            torch.mul(user_all_embeddings[user], item_all_embeddings[item]).sum(dim=1))
        bce_loss = self.loss(output, label)
        reg_loss = self.reg_loss(self.user_embedding(user), self.item_embedding(item))
        return bce_loss + self.reg_weight * reg_loss

    def calculate_loss(self, interaction):
        """Compute the training loss for one batch, depending on ``config['setting']``.

        Returns:
            A single loss tensor for ``'transfer'`` (source domain) and ``'source_passing'``
            (target domain), or a ``(source_loss, target_loss)`` tuple for ``'merge'``.

        Raises:
            ValueError: if ``config['setting']`` is not one of the supported settings.
        """
        self.init_restore_e()
        user_all_embeddings, item_all_embeddings = self.forward(self.merge_dgl_graph)
        setting = self.config['setting']

        def source_loss():
            return self._domain_loss(interaction[self.SOURCE_USER_ID],
                                     interaction[self.SOURCE_ITEM_ID],
                                     interaction[self.SOURCE_LABEL],
                                     user_all_embeddings, item_all_embeddings)

        def target_loss():
            return self._domain_loss(interaction[self.TARGET_USER_ID],
                                     interaction[self.TARGET_ITEM_ID],
                                     interaction[self.TARGET_LABEL],
                                     user_all_embeddings, item_all_embeddings)

        if setting == 'transfer':
            return source_loss()
        if setting == 'merge':
            return source_loss(), target_loss()
        if setting == 'source_passing':
            # Regularize the same embedding tables that forward() propagates.
            return target_loss()
        raise ValueError(
            f"Unsupported setting: {setting!r}; "
            "expected one of 'transfer', 'merge', 'source_passing'"
        )

    def full_sort_predict(self, interaction):
        """Score every target-domain item for each target user in the batch (flattened).

        Assumes target-domain items occupy the first ``target_num_items`` rows of the unified
        item id space — TODO confirm against the dataset id remapping.
        """
        user = interaction[self.TARGET_USER_ID]
        restore_user_e, restore_item_e = self.get_restore_e()
        u_embeddings = restore_user_e[user]
        i_embeddings = restore_item_e[:self.target_num_items]
        scores = torch.matmul(u_embeddings, i_embeddings.transpose(0, 1))
        return scores.view(-1)

    def init_restore_e(self):
        """Clear the cached evaluation embeddings (called at the start of each training step)."""
        if self.target_restore_user_e is not None or self.target_restore_item_e is not None:
            self.target_restore_user_e, self.target_restore_item_e = None, None

    def get_restore_e(self):
        """Return cached full-graph embeddings, computing them on first use."""
        if self.target_restore_user_e is None or self.target_restore_item_e is None:
            self.target_restore_user_e, self.target_restore_item_e = self.forward(self.merge_dgl_graph)
        return self.target_restore_user_e, self.target_restore_item_e